from tensorflow.keras.callbacks import TensorBoard
import tensorflow as tf
from tensorflow_model_optimization.sparsity import keras as sparsity
from tensorflow.keras import optimizers
from tensorflow.keras.callbacks import LearningRateScheduler
import numpy as np

import tempfile
import zipfile
import os
from tool import common
import threading


def pruning_thread():
    """Launch magnitude-based weight pruning of a saved Keras model in a background thread.

    The worker loads a pre-trained model from disk, wraps it with a
    polynomial-decay pruning schedule (50% -> 90% sparsity), fine-tunes it on
    the project dataset, strips the pruning wrappers, and saves the pruned
    model back to disk.

    Returns:
        threading.Thread: the already-started worker thread, so callers may
        ``join()`` it or check ``is_alive()``. (Previously nothing was
        returned, leaving callers no way to wait for completion.)
    """

    def pruning():
        """Worker body: prune, fine-tune, evaluate, strip, and save the model."""
        # NOTE(review): absolute Windows path — assumes this exact workstation layout.
        model = tf.keras.models.load_model('E:\\workspace python\\tf2\\LWHD2\\src\\model\\tic_0305.h5')

        epoch_num = 12
        # Total optimizer steps over all epochs: ceil(num_samples / batch_size) * epochs.
        # 6845*4 samples, batch size 64 — presumably matches get_dataset_train(); verify.
        end_step = np.ceil(1.0 * 6845 * 4 / 64).astype(np.int32) * epoch_num
        print(end_step)

        # Ramp sparsity from 50% to 90% over the whole training run,
        # updating the pruning masks every 100 steps.
        new_pruning_params = {
            'pruning_schedule': sparsity.PolynomialDecay(initial_sparsity=0.50,
                                                         final_sparsity=0.90,
                                                         begin_step=0,
                                                         end_step=end_step,
                                                         frequency=100)
        }

        new_pruned_model = sparsity.prune_low_magnitude(model, **new_pruning_params)
        new_pruned_model.summary()

        learning_rate = 1e-3 * 0.8

        def scheduler(epoch):
            """Step-decay schedule: full LR for the first 40% of epochs,
            x0.1 until 70%, then x0.05 for the remainder."""
            if epoch < epoch_num * 0.4:
                return learning_rate
            if epoch < epoch_num * 0.7:
                return learning_rate * 0.1
            return learning_rate * 0.05

        # 'learning_rate' is the supported keyword; the old 'lr' alias is
        # deprecated and removed in recent Keras releases.
        sgd = optimizers.SGD(learning_rate=learning_rate, momentum=0.9, nesterov=True)
        change_lr = LearningRateScheduler(scheduler)
        new_pruned_model.compile(sgd, loss="sparse_categorical_crossentropy", metrics=["accuracy"])

        logdir = os.path.join("tflog")
        print('Writing training logs to ' + logdir)

        # UpdatePruningStep is required when training a pruned model;
        # PruningSummaries logs sparsity/threshold curves to TensorBoard.
        callbacks = [
            sparsity.UpdatePruningStep(),
            sparsity.PruningSummaries(log_dir=logdir, profile_batch=0),
            change_lr
        ]

        new_pruned_model.fit(common.dataProcess_tic_type.get_dataset_train(),
                  epochs=epoch_num,
                  verbose=1,
                  callbacks=callbacks)

        # score = [loss, accuracy] per the compile() metrics above.
        score = new_pruned_model.evaluate(common.dataProcess_tic_type.get_dataset_test())

        print('Test loss:', score[0])
        print('Test accuracy:', score[1])

        # Remove the pruning wrappers so the saved model is a plain Keras model.
        final_model = sparsity.strip_pruning(new_pruned_model)
        final_model.summary()

        new_pruned_keras_file = 'E:\\workspace python\\tf2\\LWHD2\\src\\model\\pruned_0305.h5'
        tf.keras.models.save_model(final_model, new_pruned_keras_file, 
                                include_optimizer=False)

        print("Size of the pruned model before compression: %.2f Mb" 
              % (os.path.getsize(new_pruned_keras_file) / float(2**20)))

    th = threading.Thread(target=pruning)
    th.start()
    return th